import json
import os
from typing import Annotated

from langchain_community.chat_models import ChatZhipuAI
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import ToolMessage
from langgraph.graph import START, END, StateGraph
from langgraph.graph.message import add_messages
from typing_extensions import TypedDict

# SECURITY(review): these literals look like live credentials committed to
# source control. They are kept here only as fallbacks so the script keeps
# running as before; rotate them and supply the keys through the environment
# (or a secrets manager) instead of the source file.
#
# setdefault() is used so that a key already exported by the caller takes
# precedence instead of being silently overwritten at import time.
os.environ.setdefault("ZHIPUAI_API_KEY", "97738d4998b8732d707daf91a2b1c56d.2y6VKEuOlidwHDpI")
os.environ.setdefault("TAVILY_API_KEY", "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")


class BasicToolNode:
    """Graph node that executes the tool calls requested by the latest message.

    Tools are indexed by their ``name`` attribute so that each requested call
    can be dispatched with a single dictionary lookup.
    """

    def __init__(self, tools: list):
        """Index *tools* by name.

        Args:
            tools: Objects exposing a ``name`` attribute and an
                ``invoke(args)`` method.
        """
        self.tools_by_name = {tool.name: tool for tool in tools}

    # NOTE: ``input`` shadows the builtin, but the parameter name is part of
    # the callable's signature, so it is left unchanged.
    def __call__(self, input: dict):
        """Run every tool call on the last message and collect the results.

        Args:
            input: State mapping whose ``"messages"`` entry is a non-empty
                sequence; the last item must expose ``tool_calls``, an
                iterable of dicts with ``"name"``, ``"args"`` and ``"id"``.

        Returns:
            ``{"messages": [...]}`` holding one ``ToolMessage`` per tool
            call, in the order the calls were requested.

        Raises:
            ValueError: If ``input`` contains no messages.
            KeyError: If a tool call names a tool that was not registered.
        """
        messages = input.get("messages", [])
        if not messages:
            raise ValueError("No message found in input")
        message = messages[-1]

        outputs = []
        for tool_call in message.tool_calls:
            tool_name = tool_call["name"]
            try:
                tool = self.tools_by_name[tool_name]
            except KeyError:
                # Same exception type as a plain dict miss, but with enough
                # context to diagnose a model that invented a tool name.
                raise KeyError(
                    f"Unknown tool {tool_name!r}; available tools: "
                    f"{sorted(self.tools_by_name)}"
                ) from None
            tool_result = tool.invoke(tool_call["args"])
            outputs.append(
                ToolMessage(
                    # ensure_ascii=False keeps non-ASCII text (e.g. Chinese
                    # search results) as real characters instead of \uXXXX
                    # escapes in the message content.
                    content=json.dumps(tool_result, ensure_ascii=False),
                    name=tool_name,
                    tool_call_id=tool_call["id"],
                )
            )
        return {"messages": outputs}


# One Tavily search tool, limited to two results per query. The same list is
# handed both to the tool-executing node and (later) to the model binding.
tool = TavilySearchResults(max_results=2)
tools = [tool]
tool_node = BasicToolNode(tools=tools)


class State(TypedDict):
    """Schema of the graph state: a single ``messages`` list."""

    # ``add_messages`` is attached as annotation metadata; presumably LangGraph
    # uses it as the reducer that merges each node's returned messages into
    # this list rather than replacing it -- confirm against the LangGraph docs.
    messages: Annotated[list, add_messages]


# Chat model: Zhipu "glm-4" at temperature 0.5, with the tool list bound to it.
llm = ChatZhipuAI(model="glm-4", temperature=0.5)
llm = llm.bind_tools(tools)

# Builder for a graph whose state follows the State schema.
graph_builder = StateGraph(State)


def chatbot(state: State):
    """Invoke the model on the conversation so far.

    Args:
        state: Graph state; its ``"messages"`` entry is passed to the model.

    Returns:
        A state update ``{"messages": [reply]}`` containing the model's reply.
    """
    history = state["messages"]
    reply = llm.invoke(history)
    return {"messages": [reply]}


def route_tools(state: State):
    """Decide whether the graph should run tools next or stop.

    Args:
        state: Either the graph state mapping (with a ``"messages"`` list)
            or a bare list of messages.

    Returns:
        ``"tools"`` when the last message carries at least one tool call,
        otherwise ``END``.

    Raises:
        ValueError: If no messages are present, whether ``state`` is an
            empty list or a mapping without messages.
    """
    messages = state if isinstance(state, list) else state.get("messages", [])
    if not messages:
        # Covers the empty-list form too, which would otherwise surface as an
        # IndexError from ``state[-1]``.
        raise ValueError(f"No messages found in input state to tool_edge: {state}")
    ai_message = messages[-1]
    # getattr + truthiness handles a missing attribute, ``None`` and an empty
    # list alike, without calling len() on a non-sized value.
    if getattr(ai_message, "tool_calls", None):
        return "tools"
    return END


# Register the two nodes: the model step and the tool-execution step.
graph_builder.add_node("chatbot", chatbot)
graph_builder.add_node("tools", tool_node)
# Every run enters at the chatbot; tool output is routed back to the chatbot.
graph_builder.add_edge(START, "chatbot")
graph_builder.add_edge("tools", "chatbot")
# No unconditional chatbot -> END edge is added: after each chatbot step,
# route_tools returns either "tools" or END and the mapping below resolves it.
graph_builder.add_conditional_edges("chatbot", route_tools, {"tools": "tools", END: END})

graph = graph_builder.compile()


def stream_graph_updates(user_input: str):
    """Stream one user turn through the graph and print each update.

    Args:
        user_input: Text of the user's message.

    For every event yielded by ``graph.stream`` and every value in that
    event, the content of the last message in the value is printed with an
    ``Assistant:`` prefix.
    """
    initial_state = {"messages": [("user", user_input)]}
    for update in graph.stream(initial_state):
        for node_output in update.values():
            latest = node_output["messages"][-1]
            print("Assistant:", latest.content)


def _run_chat_loop() -> None:
    """Prompt for user input and stream the graph's replies until told to stop.

    The loop ends when the user types ``quit``, ``exit`` or ``q``
    (case-insensitive, surrounding whitespace ignored), when stdin reaches
    end-of-file, or when Ctrl-C is pressed at the prompt.
    """
    while True:
        try:
            user_input = input("User:")
        except (EOFError, KeyboardInterrupt):
            # Closed/piped stdin or Ctrl-C: leave cleanly instead of crashing
            # with a traceback.
            print("\nGoodbye!")
            break
        command = user_input.strip().lower()
        if command in ("quit", "exit", "q"):
            print("Goodbye!")
            break
        if not command:
            # Nothing to send; prompt again rather than invoking the model
            # on an empty message.
            continue
        stream_graph_updates(user_input)


# Start the interactive loop only when executed as a script, so importing this
# module (e.g. to reuse ``graph``) does not block on stdin.
if __name__ == "__main__":
    _run_chat_loop()
